package chagine.core.aspect;

import chagine.core.Env;
import chagine.core.ErrorCode;
import chagine.core.annotation.AccessLimit;
import chagine.core.annotation.AccessLimit.ArgsRepeat;
import chagine.core.annotation.AccessLimit.LimitType;
import chagine.core.current.CurrentEnv;
import chagine.core.current.RequestContent;
import chagine.core.util.Assert;
import cn.hutool.extra.servlet.ServletUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.support.atomic.RedisAtomicLong;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

@Slf4j
@Aspect
@Component
public class AccessLimitAspect {

    private static final String SEPARATOR = ":";
    private static final String NO_PARAMS_VALUE = "VALUE_IS_EMPTY";

    @Value("${spring.application.name:NO_APP_NAME}")
    private String appName;

    @Resource
    private StringRedisTemplate stringRedisTemplate;

    @Resource
    private ObjectMapper objectMapper;

    /**
     * Redis key prefix ({@code "AccessLimit:<appName>:"}), built in {@link #init()} once
     * {@code appName} has been injected. Kept per instance: a static field written from an
     * instance callback would be shared by every instance in the JVM and overwritten by
     * whichever one initialized last.
     */
    private String defaultKey;

    /**
     * Pointcut matching every method annotated with {@link AccessLimit}.
     */
    @Pointcut("@annotation(chagine.core.annotation.AccessLimit)")
    public void rateLimit() {
    }

    /**
     * Builds the Redis key prefix after dependency injection has set {@code appName}.
     */
    @PostConstruct
    public void init() {
        defaultKey = "AccessLimit" + SEPARATOR + appName + SEPARATOR;
    }

    /**
     * Applies the {@link AccessLimit} settings of the intercepted method, then proceeds.
     *
     * <p>Reads the current HTTP request from {@link RequestContextHolder} (so this assumes the
     * call happens on a servlet request thread), builds the Redis key for the configured
     * {@link LimitType}, and asks {@link #shouldLimited} whether the call must be rejected.
     * {@code Assert.error} is then given the verdict; presumably it aborts the call when
     * {@code limited} is true (TODO confirm against {@code chagine.core.util.Assert}).
     *
     * @param point the intercepted join point
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by {@code Assert.error} or by the intercepted method
     */
    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        HttpServletRequest httpServletRequest = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        String clientIp = ServletUtil.getClientIP(httpServletRequest);
        log.info("请求客户端:{}", clientIp);
        RequestContent content = CurrentEnv.getContent();

        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        AccessLimit accessLimit = AnnotationUtils.findAnnotation(method, AccessLimit.class);

        if (accessLimit != null) {
            // Base Redis key for this method + limit dimension (user / tenant / device / IP / fixed key).
            String key = getKey(method, accessLimit, content, httpServletRequest, clientIp);

            long maxRequestCount = accessLimit.maxRequestCount();
            long timeout = accessLimit.timeout();
            TimeUnit timeUnit = accessLimit.timeUnit();
            boolean enableArgsRepeat = accessLimit.argsRepeat() == ArgsRepeat.OPEN;
            String argsValue = enableArgsRepeat ? getArgsJson(point) : null;
            // debug rather than info: the serialized arguments may carry personal data or credentials.
            log.debug("access limit args_value:{}", argsValue);

            LimitStatus limitStatus = shouldLimited(enableArgsRepeat, key, argsValue, maxRequestCount, timeout, timeUnit);
            Assert.error(limitStatus.limited, ErrorCode.QPS_EXCEEDS_LIMIT, limitStatus.type == 1 ? "请求次数过于频繁，请刷新或请稍后再试" : "请勿短时间提交重复数据，请刷新");
        }
        return point.proceed();
    }

    /**
     * Decides whether the current request must be rejected and, if not, records it.
     *
     * <p>A window starts with the first accepted request. That request writes a marker key
     * ({@code <key>:Value:}) with a TTL of {@code timeout} and sets the counter
     * ({@code <key>:Count}) to 1. While the marker exists, accepted requests increment the
     * counter, and once the counter reaches {@code maxRequestCount} further requests are
     * rejected with type 1. When duplicate detection is enabled, a request whose serialized
     * arguments were already seen within {@code timeout} is rejected with type 2. Rejected
     * requests are not counted.
     *
     * @param enableArgsRepeat whether duplicate-argument detection is enabled
     * @param key              base Redis key for this method and limit dimension
     * @param argsValue        serialized request arguments; {@code null} when detection is
     *                         disabled or serialization failed, in which case no duplicate
     *                         check is performed
     * @param maxRequestCount  maximum accepted requests per window
     * @param timeout          window / duplicate-detection duration
     * @param timeUnit         unit of {@code timeout}
     * @return the verdict; {@code limited == false} means the request was accepted and counted
     */
    private LimitStatus shouldLimited(boolean enableArgsRepeat, String key, String argsValue, long maxRequestCount, long timeout, TimeUnit timeUnit) {
        String expiredKey = key + SEPARATOR + "Value" + SEPARATOR;
        String countKey = key + SEPARATOR + "Count";

        // Request counter for the current window.
        RedisAtomicLong redisAtomicLong = new RedisAtomicLong(countKey, stringRedisTemplate.getConnectionFactory());

        // The marker key decides whether a window is active. If it is gone, the window is over and
        // any leftover counter value is stale. It may even lack a TTL, because the RedisAtomicLong
        // constructor creates missing keys without one. A stale counter must not block the request,
        // otherwise a counter stuck at the maximum would reject this caller forever.
        boolean needInit = StringUtils.isBlank(stringRedisTemplate.opsForValue().get(expiredKey));
        long currentCount = needInit ? 0L : redisAtomicLong.get();
        if (currentCount >= maxRequestCount) {
            // Too many requests in the current window.
            LimitStatus status = new LimitStatus();
            status.limited = true;
            return status;
        }

        String saveNewValue = StringUtils.isBlank(argsValue) ? NO_PARAMS_VALUE : argsValue;

        if (enableArgsRepeat && StringUtils.isNotBlank(argsValue)) {
            // setIfAbsent is a single atomic SET NX, so two identical concurrent submissions cannot
            // both pass (a separate get-then-set would let both through).
            Boolean stored = stringRedisTemplate.opsForValue().setIfAbsent(expiredKey + saveNewValue, saveNewValue, timeout, timeUnit);
            if (Boolean.FALSE.equals(stored)) {
                // Same arguments already submitted within the detection period.
                LimitStatus status = new LimitStatus();
                status.limited = true;
                status.type = 2;
                return status;
            }
        }

        intOrAddAtomic(needInit, expiredKey, saveNewValue, redisAtomicLong, timeout, timeUnit);
        return new LimitStatus();
    }

    /**
     * Serializes the intercepted method's arguments to JSON, keyed as {@code args_0},
     * {@code args_1}, ...
     *
     * @return the JSON string, or {@code null} if serialization failed. Returning {@code null}
     *         disables duplicate detection for this call. Returning a shared sentinel would make
     *         every request with non-serializable arguments look like a duplicate of the others.
     */
    private String getArgsJson(ProceedingJoinPoint point) {
        Object[] args = point.getArgs();
        Map<String, Object> params = new HashMap<>();
        for (int i = 0; i < args.length; i++) {
            params.put("args_" + i, args[i]);
        }
        try {
            return objectMapper.writeValueAsString(params);
        } catch (JsonProcessingException e) {
            log.error("AccessLimit JSON create ERROR >>>>", e);
            return null;
        }
    }

    /**
     * Builds the base Redis key:
     * {@code AccessLimit:<appName>:<declaringClass>:<method>} plus a suffix chosen by
     * {@link AccessLimit#type()}. The suffix is the fixed annotation key, the user id, the
     * tenant id, the device header, or the client IP. Missing ids become {@code NO_ID}.
     */
    private String getKey(Method method, AccessLimit accessLimit, RequestContent content, HttpServletRequest httpServletRequest, String clientIp) {
        LimitType type = accessLimit.type();
        String key = defaultKey + method.getDeclaringClass().getName() + SEPARATOR + method.getName();
        if (type == LimitType.KEY) {
            key = key + SEPARATOR + accessLimit.key();
        } else if (type == LimitType.USER) {
            Long user = content.getUser();
            key = key + SEPARATOR + "USER_ID" + SEPARATOR + (user == null ? "NO_ID" : user);
        } else if (type == LimitType.TENANT) {
            Long tenant = content.getTenant();
            key = key + SEPARATOR + "TENANT_ID" + SEPARATOR + (tenant == null ? "NO_ID" : tenant);
        } else if (type == LimitType.DEVICE) {
            String device = httpServletRequest.getHeader(Env.HEAD_DEVICE_KEY);
            key = key + SEPARATOR + "DEVICE" + SEPARATOR + (StringUtils.isBlank(device) ? "NO_ID" : device);
        } else if (type == LimitType.NO_SETTING) {
            key = key + SEPARATOR + clientIp;
        }
        return key;
    }

    /**
     * Result of a limit check. Static so it does not capture a reference to the aspect.
     */
    @Data
    public static class LimitStatus {
        /**
         * Whether the request must be rejected.
         */
        private boolean limited = false;
        /**
         * Rejection reason: 1 = requests too frequent (default), 2 = duplicate submission
         * within a short period.
         */
        private int type = 1;
    }

    /**
     * Starts a new window: writes the marker key with the given TTL, sets the counter to 1,
     * and gives the counter the same TTL.
     *
     * @param expiredKey      window marker key
     * @param argsValue       value stored under the marker key (only its non-blankness is read back)
     * @param redisAtomicLong the window's request counter
     * @param timeout         window duration
     * @param timeUnit        unit of {@code timeout}
     */
    public void initKeyCount(String expiredKey, String argsValue, RedisAtomicLong redisAtomicLong, long timeout, TimeUnit timeUnit) {
        stringRedisTemplate.opsForValue().set(expiredKey, argsValue, timeout, timeUnit);
        redisAtomicLong.set(1);
        redisAtomicLong.expire(timeout, timeUnit);
    }

    /**
     * Starts a new window via {@link #initKeyCount} when {@code needInit} is true; otherwise
     * increments the existing counter by one. (The name is a historical typo for "init or add";
     * it is kept because the method is public.)
     */
    public void intOrAddAtomic(boolean needInit, String expiredKey, String argsValue, RedisAtomicLong redisAtomicLong, long timeout, TimeUnit timeUnit) {
        if (needInit) {
            initKeyCount(expiredKey, argsValue, redisAtomicLong, timeout, timeUnit);
        } else {
            redisAtomicLong.incrementAndGet();
        }
    }
}
